# datasets/transforms.py
from __future__ import annotations

from typing import Optional, Dict, Any, Tuple

import numpy as np
from sklearn.decomposition import PCA

def normalize_total(X: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """
    Rescale every row of X so that its entries add up to target_sum.

    The input is never modified: the work happens on a float32 copy, which
    is what gets returned. Rows whose sum is exactly zero are left as they
    are (their divisor is replaced by 1 to avoid dividing by zero).
    """
    scaled = X.astype(np.float32, copy=True)
    row_totals = scaled.sum(axis=1, keepdims=True)

    # Empty rows would give target_sum / 0; dividing by 1 keeps them all-zero.
    empty_rows = row_totals == 0.0
    row_totals[empty_rows] = 1.0

    factors = target_sum / row_totals
    scaled *= factors
    return scaled

def log1p_inplace(X: np.ndarray) -> np.ndarray:
    """
    Overwrite X with log(1 + X), element-wise, and return that same array.

    The result is written straight into X's own buffer, so no new array is
    allocated and the caller's array is changed.
    """
    # With out=X the ufunc hands back X itself after filling it.
    return np.log1p(X, out=X)

def hvg_simple(
    X: np.ndarray,
    n: Optional[int] = 2000,
    *,
    mean_clip: tuple[float, float] = (0.01, 99.0),
) -> tuple[np.ndarray, np.ndarray]:
    """
    Select up to n high-variance columns (genes) of X.

    Columns are first filtered by their mean: only those whose mean lies
    between the two percentiles given in mean_clip (inclusive) remain
    candidates. Among the candidates, the n with the largest variance are
    kept; if fewer than n candidates survive, all of them are kept.

    When n is None, non-positive, or not smaller than the number of
    columns, nothing is filtered and X itself is returned with an
    all-True mask.

    Returns (X_reduced, mask), where mask is a boolean vector over the
    original columns and X_reduced is X restricted to the True columns.
    """
    n_cols = X.shape[1]
    passthrough = n is None or n <= 0 or n >= n_cols
    if passthrough:
        return X, np.ones(n_cols, dtype=bool)

    # Step 1: drop columns whose mean is outside the percentile window.
    col_means = X.mean(axis=0)
    lower, upper = np.percentile(col_means, mean_clip)
    in_window = (col_means >= lower) & (col_means <= upper)
    candidates = np.flatnonzero(in_window)

    # Step 2: rank the surviving columns by variance; argsort is ascending,
    # so the n most variable candidates sit at the tail.
    variances = X[:, in_window].var(axis=0)
    ranked = np.argsort(variances)
    winners = candidates[ranked[-n:]]

    selected = np.zeros(n_cols, dtype=bool)
    selected[winners] = True
    return X[:, selected], selected

def pca_reduce(
    X: np.ndarray,
    n_components: Optional[int] = 50,
    *,
    random_state: int = 0,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """
    Project X onto its first n_components principal components.

    Uses scikit-learn's PCA with the randomized SVD solver, seeded by
    random_state. Returns (X_pca, info) where info holds the fitted
    "pca_components" and "explained_variance_ratio", all as float32.

    If n_components is None, non-positive, or not smaller than the number
    of columns of X, no PCA is run: X is returned untouched and both info
    entries are None.

    NOTE(review): only the column count is checked here. scikit-learn also
    bounds n_components by the number of rows, so inputs with few rows may
    raise from PCA itself -- confirm whether that should be handled here.
    """
    skip_pca = n_components is None or n_components <= 0 or n_components >= X.shape[1]
    if skip_pca:
        return X, {"pca_components": None, "explained_variance_ratio": None}

    model = PCA(n_components=n_components, svd_solver="randomized", random_state=random_state)
    projected = model.fit_transform(X)

    # copy=False avoids a second allocation when sklearn already gave float32.
    return projected.astype(np.float32, copy=False), {
        "pca_components": model.components_.astype(np.float32, copy=False),
        "explained_variance_ratio": model.explained_variance_ratio_.astype(np.float32, copy=False),
    }

def default_preprocess(
    X: np.ndarray,
    *,
    normalize: bool = True,
    log1p: bool = True,
    hvg_n: Optional[int] = 2000,
    pca_n: Optional[int] = 50,
    random_state: int = 0,
) -> tuple[np.ndarray, Dict[str, Any]]:
    """
    Run the standard scRNA-seq pipeline: normalize -> log1p -> HVG -> PCA.

    The normalize and log1p steps can be switched off with their flags;
    HVG selection and PCA are always invoked with hvg_n and pca_n.
    The input X is not modified: all steps operate on a float32 copy.

    Returns (X_processed, info). info contains "hvg_mask" (the boolean
    column mask from the HVG step) followed by the entries reported by
    the PCA step.
    """
    # Private float32 copy so the in-place log1p below cannot touch the input.
    data = X.astype(np.float32, copy=True)

    if normalize:
        data = normalize_total(data)
    if log1p:
        data = log1p_inplace(data)

    data, gene_mask = hvg_simple(data, n=hvg_n)
    data, pca_info = pca_reduce(data, n_components=pca_n, random_state=random_state)

    return data, {"hvg_mask": gene_mask, **pca_info}

